{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "04205.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/04205.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ucpkjwm7o0Ph",
        "colab_type": "text"
      },
      "source": [
        "## 4.2 模型参数的访问、初始化和共享"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RHsET7rkp3g1",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "44bc90ab-ccb0-4e4d-aaa1-58f598bfdb41"
      },
      "source": [
        "import torch \n",
        "from torch import nn \n",
        "from torch.nn import init \n",
        "\n",
        "# Small MLP: 4 input features -> 3 hidden units with ReLU -> 1 output.\n",
        "net = nn.Sequential(\n",
        "    nn.Linear(4, 3),\n",
        "    nn.ReLU(),\n",
        "    nn.Linear(3, 1),\n",
        ")\n",
        "net"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Sequential(\n",
              "  (0): Linear(in_features=4, out_features=3, bias=True)\n",
              "  (1): ReLU()\n",
              "  (2): Linear(in_features=3, out_features=1, bias=True)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 2
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YZ0T--HrswR9",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "bdafe9d4-d961-4dfb-ead7-baf8fda17813"
      },
      "source": [
        "# Random batch of 2 samples with 4 features each, matching the first layer's input width.\n",
        "X = torch.rand((2, 4))\n",
        "net(X)"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[0.2371],\n",
              "        [0.2078]], grad_fn=<AddmmBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cp5o-LPCs7eI",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "f86b9cbb-fb73-44c6-bddc-005b4c428aeb"
      },
      "source": [
        "# Collapse the batch outputs into one scalar so backward() can be called on it later.\n",
        "Y = torch.sum(net(X))\n",
        "Y"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor(0.4448, grad_fn=<SumBackward0>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C23-gTiXs_nP",
        "colab_type": "text"
      },
      "source": [
        "### 4.2.1 访问模型参数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r6fWZgXhtFzK",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "ea202600-24e2-481f-e001-952e84a14f1e"
      },
      "source": [
        "# named_parameters() hands back a generator rather than a list.\n",
        "named_params = net.named_parameters()\n",
        "print(type(named_params))"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "<class 'generator'>\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1o7qEv4LtNLF",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        },
        "outputId": "2fff4d9e-e077-4c61-c178-05d059e02892"
      },
      "source": [
        "# Walk every learnable tensor in the network; each name is prefixed with its layer index.\n",
        "for name, param in net.named_parameters():\n",
        "    print(name, param.shape)"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "0.weight torch.Size([3, 4])\n",
            "0.bias torch.Size([3])\n",
            "2.weight torch.Size([1, 3])\n",
            "2.bias torch.Size([1])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CBPmV2betXYU",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "28ba13e4-d8fc-472d-e044-5548e8a3f481"
      },
      "source": [
        "# Querying a single layer drops the index prefix from the parameter names.\n",
        "first_layer = net[0]\n",
        "for name, param in first_layer.named_parameters():\n",
        "    print(name, param.shape, type(param))"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "weight torch.Size([3, 4]) <class 'torch.nn.parameter.Parameter'>\n",
            "bias torch.Size([3]) <class 'torch.nn.parameter.Parameter'>\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "H-ltr5zGtyV4",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "42589637-e346-4f80-9ed1-b1bdfbaa8975"
      },
      "source": [
        "class MyModel(nn.Module):\n",
        "    \"\"\"Toy module contrasting a registered parameter with a plain tensor attribute.\"\"\"\n",
        "\n",
        "    def __init__(self, **kwargs):\n",
        "        super().__init__(**kwargs)\n",
        "        # Wrapped in nn.Parameter, so named_parameters() reports it.\n",
        "        self.weight1 = nn.Parameter(torch.rand(20, 20))\n",
        "        # Plain tensor attribute: it does not appear in the listing printed below.\n",
        "        self.weight2 = torch.rand(20, 20)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Placeholder that returns None; this demo never runs a forward pass.\"\"\"\n",
        "        return None\n",
        "\n",
        "\n",
        "n = MyModel()\n",
        "for name, param in n.named_parameters():\n",
        "    print(name)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "weight1\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xoIVOsO6ubCB",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "65c2553d-6929-4f39-b737-123bac37541b"
      },
      "source": [
        "# Layer 0 yields its weight first; .data displays the values without autograd metadata.\n",
        "weight_0 = next(net[0].parameters())\n",
        "weight_0.data"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[ 0.3120,  0.3832, -0.2590, -0.2786],\n",
              "        [-0.3275, -0.3611, -0.1754, -0.3223],\n",
              "        [-0.2355,  0.1803, -0.0546, -0.4812]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rtF6Zl7gvEJm",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "60aee6a3-fe4d-4249-d0c6-125985ebb103"
      },
      "source": [
        "# No backward pass has run yet, so the gradient slot is still empty.\n",
        "grad_before_backward = weight_0.grad\n",
        "print(grad_before_backward)"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "None\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iPhVG5iwvGaA",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "ed9b1053-c42b-4cfb-db14-e2bb310cd180"
      },
      "source": [
        "# Backpropagate from the scalar Y computed earlier; afterwards weight_0.grad is populated.\n",
        "Y.backward()\n",
        "weight_0.grad"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[0.2141, 0.1815, 0.1700, 0.1521],\n",
              "        [0.0000, 0.0000, 0.0000, 0.0000],\n",
              "        [0.0000, 0.0000, 0.0000, 0.0000]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vfz-2DmQvO_P",
        "colab_type": "text"
      },
      "source": [
        "### 4.2.2 初始化模型参数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QJumeTWavhp1",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        },
        "outputId": "d607d1da-c425-4457-eb3a-dab294bf8fe4"
      },
      "source": [
        "# Re-draw every weight tensor from a normal distribution (mean 0, std 0.01); skip the rest.\n",
        "for name, param in net.named_parameters():\n",
        "    if 'weight' not in name:\n",
        "        continue\n",
        "    init.normal_(param, mean=0, std=0.01)\n",
        "    print(name, param.data)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "0.weight tensor([[ 0.0114,  0.0106,  0.0060,  0.0084],\n",
            "        [ 0.0190, -0.0069,  0.0128,  0.0150],\n",
            "        [-0.0085, -0.0034, -0.0051, -0.0029]])\n",
            "2.weight tensor([[ 0.0094, -0.0043, -0.0024]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xh34rS_Fv4B8",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "afef8c3d-1258-4858-e283-fa41a79f7b9c"
      },
      "source": [
        "# Set every bias tensor to zero; weights are left untouched.\n",
        "bias_params = [(name, param) for name, param in net.named_parameters() if 'bias' in name]\n",
        "for name, param in bias_params:\n",
        "    init.constant_(param, val=0)\n",
        "    print(name, param.data)"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "0.bias tensor([0., 0., 0.])\n",
            "2.bias tensor([0.])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WzFySgThwJfF",
        "colab_type": "text"
      },
      "source": [
        "### 4.2.3 自定义初始化方法"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dc4SUlT6y1ml",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def normal_(tensor, mean=0, std=1):\n",
        "    \"\"\"Overwrite `tensor` in place with normal(mean, std) samples and return it.\n",
        "\n",
        "    The fill runs under torch.no_grad(), so autograd does not record it.\n",
        "    \"\"\"\n",
        "    with torch.no_grad():\n",
        "        filled = tensor.normal_(mean, std)\n",
        "    return filled"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YW5vodDOzAxb",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        },
        "outputId": "9cf01045-479c-46f4-fe91-f8a4a903cfca"
      },
      "source": [
        "def init_weight_(tensor):\n",
        "    \"\"\"Fill `tensor` in place from U(-10, 10), then zero entries with magnitude below 5.\n",
        "\n",
        "    Runs under torch.no_grad() so the in-place writes are not tracked by autograd.\n",
        "    Returns None.\n",
        "    \"\"\"\n",
        "    with torch.no_grad():\n",
        "        tensor.uniform_(-10, 10)\n",
        "        keep_mask = tensor.abs() >= 5\n",
        "        tensor.mul_(keep_mask.float())\n",
        "\n",
        "\n",
        "# Apply the custom initializer to each weight tensor and show the result.\n",
        "for name, param in net.named_parameters():\n",
        "    if 'weight' not in name:\n",
        "        continue\n",
        "    init_weight_(param)\n",
        "    print(name, param.data)"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "0.weight tensor([[ 6.3236,  6.5301,  0.0000, -0.0000],\n",
            "        [-5.0343,  0.0000,  6.4923, -0.0000],\n",
            "        [ 0.0000, -9.0378,  9.8344,  0.0000]])\n",
            "2.weight tensor([[-5.5816,  5.1283,  8.0439]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ueMfqMJ8zfg8",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "36e905ad-2e13-447f-e122-8a56d1bd24a7"
      },
      "source": [
        "# Shift every bias by +1 in place. The update runs under torch.no_grad() so it is not\n",
        "# recorded by autograd, instead of writing through the legacy `.data` attribute, whose\n",
        "# in-place changes are hidden from autograd's correctness checks.\n",
        "for name, param in net.named_parameters():\n",
        "    if 'bias' in name:\n",
        "        with torch.no_grad():\n",
        "            param.add_(1)\n",
        "        print(name, param.data)"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "0.bias tensor([1., 1., 1.])\n",
            "2.bias tensor([1.])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PoubEtCtz9cx",
        "colab_type": "text"
      },
      "source": [
        "### 4.2.4 共享模型参数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6koXjqeD0e2v",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        },
        "outputId": "18b0b9a2-348b-4048-8ac8-4a4649131ef9"
      },
      "source": [
        "# Place one bias-free Linear instance in both positions so the two layers share a weight.\n",
        "linear = nn.Linear(in_features=1, out_features=1, bias=False)\n",
        "net = nn.Sequential(linear, linear)\n",
        "net"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Sequential(\n",
              "  (0): Linear(in_features=1, out_features=1, bias=False)\n",
              "  (1): Linear(in_features=1, out_features=1, bias=False)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SeQod7_q0msF",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "484ba418-3a86-4a1f-f5a2-28a6ad528ea5"
      },
      "source": [
        "# The shared weight is reported only once, under the first layer's name.\n",
        "for name, param in net.named_parameters():\n",
        "    init.constant_(param, 3)\n",
        "    print(name, param.data)"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "0.weight tensor([[3.]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8WqJaJQL0w8x",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "536138d4-ca6b-4ba9-938e-b5ad5ffff564"
      },
      "source": [
        "# Both positions in the Sequential refer to the very same module object.\n",
        "net[0] is net[1]"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "n7gTz7M61UF-",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "8f288baa-c1a0-46fe-f827-ea9579961900"
      },
      "source": [
        "# ...and therefore the very same weight Parameter.\n",
        "net[0].weight is net[1].weight"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 21
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z9FprhaO1bXX",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "6a27bd1e-4f40-40ac-e6b6-eb89be835510"
      },
      "source": [
        "# The weight (set to 3) is applied twice to an input of 1, so the output is 3 * 3 * 1 = 9.\n",
        "x = torch.ones((1, 1))\n",
        "y = torch.sum(net(x))\n",
        "y"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor(9., grad_fn=<SumBackward0>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3gGCeb562adL",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "0b1403eb-ab71-45f9-a717-66ebac4a71bc"
      },
      "source": [
        "# Gradients from both uses of the shared weight accumulate into a single .grad:\n",
        "# y = w * w * x, so dy/dw = 2 * w * x = 2 * 3 * 1 = 6.\n",
        "y.backward()\n",
        "net[0].weight.grad"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[6.]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 23
        }
      ]
    }
  ]
}